#!/usr/bin/env python
# coding: utf-8
# test sync in gd

from utils_new import *
import datetime
import os
from tqdm import tqdm
import argparse
import scipy
import multiprocessing as mp
from models import *
from samplers import *
from scipy import stats

import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)


parser = argparse.ArgumentParser(description='Training GCN on Cora/PPI/PubMed/Reddit Datasets')

# Command-line options: dataset selection, model size, sampling settings
# and training/recording controls.
parser.add_argument('--dataset', type=str, default='cora',
                    help='Dataset name: cora/citeseer/pubmed/reddit/ppi/ppi-large')
parser.add_argument('--nhid', type=int, default=64,
                    help='Hidden state dimension')
parser.add_argument('--epoch_num', type=int, default=100,
                    help='Number of Epoch')
parser.add_argument('--pool_num', type=int, default=10,
                    help='Number of Pool')
parser.add_argument('--batch_num', type=int, default=10,
                    help='Maximum Batch Number')
parser.add_argument('--batch_size', type=int, default=128,
                    help='size of output node in a batch')
parser.add_argument('--n_layers', type=int, default=2,
                    help='Number of GCN layers')
parser.add_argument('--n_iters', type=int, default=1,
                    help='Number of iteration to run on a batch')
parser.add_argument('--n_stops', type=int, default=200,
                    help='Stop after number of batches that f1 dont increase')
parser.add_argument('--samp_num', type=int, default=128,
                    help='Number of sampled nodes per layer')
parser.add_argument('--cuda', type=int, default=0,
                    help='Avaiable GPU ID: -1 for cpu, 0 for gpu')
parser.add_argument('--n_trial', type=int, default=1,
                    help='Number of times to repeat experiments')
parser.add_argument('--record_f1', type=int, default=1,
                    help='Record the f1 score')
parser.add_argument('--full_valid', type=int, default=1,
                    help='Use all neighbors for validation')
parser.add_argument('--samp_growth_rate', type=float, default=1,
                    help='Growth rate for layer-wise sampling')

# Parse once; the namespace is read by the rest of the script.
args = parser.parse_args()


# Select the compute device: --cuda -1 means CPU, any other value is used
# as the CUDA device index.
device_name = "cpu" if args.cuda == -1 else "cuda:" + str(args.cuda)
device = torch.device(device_name)
    
# Load the dataset selected on the command line.
print(args.dataset)
adj_matrix, labels, feat_data, train_nodes, valid_nodes, test_nodes = preprocess_data(args.dataset)

n_nodes = feat_data.shape[0]

# Keep a handle on the data exactly as returned by preprocess_data
# (before the feature matrix is converted to a tensor below).
data_save = (adj_matrix, labels, feat_data, train_nodes, valid_nodes, test_nodes)

print("n train, val, test")
print(len(train_nodes), len(valid_nodes), len(test_nodes))
print("batch_size: ", args.batch_size, ", sample_size: ", args.samp_num)

# Get Laplacian: D^{-1/2} A_tilde D^{-1/2}, where A_tilde is the adjacency
# matrix with self-loops added.
lap_matrix = normalize_lap(adj_matrix + sp.eye(adj_matrix.shape[0]))

# Move the node features to the target device as a dense float tensor.
# issparse() covers every SciPy sparse format (not only LIL) and avoids the
# deprecated private `scipy.sparse.lil` namespace.
if scipy.sparse.issparse(feat_data):
    feat_data = torch.FloatTensor(feat_data.toarray()).to(device)
else:
    feat_data = torch.FloatTensor(feat_data).to(device)

# Pick the loss and label encoding for the task:
#   - multi-label datasets: BCE-with-logits on float labels; the number of
#     classes is the width of the label matrix
#   - single-label datasets: cross-entropy on integer labels; the number of
#     classes is the largest label id plus one
multi_label = args.dataset in ['ppi', 'ppi-large', 'yelp']

if not multi_label:
    loss_func = F.cross_entropy
    labels = torch.LongTensor(labels).to(device)
    num_classes = labels.max().item() + 1
else:
    loss_func = nn.BCEWithLogitsLoss()
    labels = torch.FloatTensor(labels).to(device)
    num_classes = labels.shape[1]


# Samplers to benchmark, in the order they are run.
sample_method_ls = ["fastgcn", "sketch", "ladies_wrs", "sketch_wrs", "ladies"]

# Values forwarded to record_result for every sampler.
write_file = "w"
original_stdout = sys.stdout

# Results collected across samplers; the parsed arguments are stored with them.
result_pkl = {"args": args}

# Timestamp-based run id (spaces and colons replaced so it is usable in file
# names); shared by the checkpoint file and the result files.
run_started = str(datetime.datetime.now())
best_model_idx = run_started.replace(' ', '_').replace(':', '.')

filename = "main_{}_lay_{}_{}".format(args.dataset, args.n_layers, best_model_idx)

# Benchmark loop: for every sampler, train `n_trial` fresh models with early
# stopping on validation F1, evaluate the best checkpoint on the test nodes
# with full-batch inference, and record timings / F1 scores.

# Sampler name -> sampling routine. An unknown name now fails loudly with a
# KeyError instead of silently reusing the previous iteration's sampler.
sampler_by_name = {
    'ladies': ladies_sampler,
    'fastgcn': fastgcn_sampler,
    'full': full_batch_sampler,
    'sketch': sketch_sampler,
    'ladies_wrs': ladies_sampler_wrs,
    'sketch_wrs': sketch_sampler_wrs,
}

# Number of process ids handed to prepare_data per sampling round.
batches_per_round = 10

# At least one sampling round per epoch; otherwise (batch_num < 10) no
# validation batch would ever be fetched and validation would fail.
n_iter = max(1, args.batch_num // batches_per_round)

# Early-stopping patience in epochs. Clamped to >= 1 so that a large
# batch_num cannot produce a threshold of 0, which would stop training
# right after the first improving epoch.
patience = max(1, args.n_stops // max(1, args.batch_num))

# The checkpoint directory must exist before torch.save is called.
os.makedirs('./save', exist_ok=True)
best_model_path = './save/best_model_{}.pt'.format(best_model_idx)

for sample_method in sample_method_ls:

    sampler = sampler_by_name[sample_method]

    print("Sampler: ", sample_method)

    process_ids = np.arange(batches_per_round)

    # Per-layer sample sizes: samp_num * samp_growth_rate ** layer_index,
    # truncated to integers.
    samp_num_list = np.array(args.samp_num * args.samp_growth_rate ** np.arange(args.n_layers), dtype=int)
    print("Sampler: ", sample_method, "batch_size: ", args.batch_size, "batch_num: ",
          args.batch_num, "sample_num: ", samp_num_list)

    # Per-sampler records, one entry per trial.
    total_time_all = []
    test_f1_all = []
    epoch_time_all = []
    epoch_num = []
    valid_f1_all = []

    pool = mp.Pool(args.pool_num)
    try:
        jobs = prepare_data(pool, sampler, process_ids, train_nodes, valid_nodes, samp_num_list,
                            len(feat_data), lap_matrix, args.n_layers, args.batch_size)

        for oiter in range(args.n_trial):

            # Fresh model and optimizer for every trial.
            encoder = GCN(nfeat=feat_data.shape[1], nhid=args.nhid,
                          layers=args.n_layers, dropout=0.2).to(device)
            susage = SuGCN(encoder=encoder, num_classes=num_classes, dropout=0.2,
                           inp=feat_data.shape[1])
            susage.to(device)

            optimizer = optim.Adam(filter(lambda p: p.requires_grad, susage.parameters()))
            best_val, best_tst = -1, -1
            cnt = 0                    # epochs since the last validation improvement
            times = []                 # per-batch training times, accumulated over the trial
            epoch_time = []            # cumulative training time after each epoch
            valid_f1_single_iter = []  # validation F1 after each epoch
            print('-' * 10)
            for epoch in np.arange(args.epoch_num):
                susage.train()
                train_losses = []
                # Use CPU-GPU cooperation to reduce the overhead for sampling
                # (conduct sampling while training): the batches requested in
                # the previous round are collected, the next round is
                # submitted, and only then does training on the collected
                # batches start.
                for _iter in range(n_iter):

                    # All jobs but the last are training batches; the last
                    # one is the validation batch.
                    train_data = [job.get() for job in jobs[:-1]]
                    valid_data = jobs[-1].get()
                    pool.close()
                    pool.join()
                    pool = mp.Pool(args.pool_num)
                    jobs = prepare_data(pool, sampler, process_ids, train_nodes, valid_nodes,
                                        samp_num_list, len(feat_data), lap_matrix, args.n_layers,
                                        args.batch_size)

                    for adjs, input_nodes, output_nodes, after_nodes_ls in train_data:
                        adjs = package_mxl(adjs, device)
                        optimizer.zero_grad()

                        t1 = time.time()
                        output = susage.forward(feat_data[input_nodes], adjs)

                        # loss_func depends on the task (multi-label vs single-label)
                        loss_train = loss_func(output, labels[output_nodes])
                        loss_train.backward()
                        optimizer.step()

                        times.append(time.time() - t1)
                        train_losses.append(loss_train.detach().tolist())
                        del loss_train

                    print(np.sum(times))

                # Perform validation at the end of each epoch.
                epoch_time.append(np.sum(times))
                susage.eval()
                adjs, input_nodes, output_nodes, after_nodes_ls = valid_data
                adjs = package_mxl(adjs, device)

                # No gradients are needed for validation, so skip building
                # the autograd graph.
                with torch.no_grad():
                    output = susage.forward(feat_data[input_nodes], adjs)
                    loss_valid = loss_func(output, labels[output_nodes]).detach().tolist()

                valid_f1 = eval_f1(output, labels, output_nodes, num_classes, multi_label)

                print(("Epoch: %d (%.1fs) Train Loss: %.2f    Valid Loss: %.2f  Valid F1: %.3f") % \
                                  (epoch, np.sum(times), np.average(train_losses), loss_valid, valid_f1))
                valid_f1_single_iter.append(valid_f1)

                # Only an improvement of more than 0.01 F1 counts as progress.
                if valid_f1 > best_val + 1e-2:
                    best_val = valid_f1
                    torch.save(susage, best_model_path)
                    cnt = 0
                else:
                    cnt += 1
                if cnt >= patience:
                    break

            # NOTE(review): this pickles/unpickles the whole module. Recent
            # PyTorch releases default torch.load to weights_only=True, which
            # rejects such checkpoints; confirm the target PyTorch version or
            # pass weights_only=False there. Only load trusted files.
            best_model = torch.load(best_model_path)
            best_model.eval()

            # Full-batch inference on the test nodes.
            batch_nodes = test_nodes
            adjs, input_nodes, output_nodes, _ = full_batch_sampler(np.random.randint(2**32 - 1), batch_nodes,
                                                                    samp_num_list, len(feat_data),
                                                                    lap_matrix, args.n_layers)
            adjs = package_mxl(adjs, device)

            with torch.no_grad():
                output = best_model.forward(feat_data[input_nodes], adjs)

            test_f1 = eval_f1(output, labels, output_nodes, num_classes, multi_label)
            test_f1s = [test_f1]

            print('Iteration: %d, Test F1: %.3f' % (oiter, np.average(test_f1s)))

            total_time_all.append(np.sum(times))
            test_f1_all.append(test_f1)
            epoch_num.append(epoch)
            epoch_time_all.append(epoch_time)
            valid_f1_all.append(valid_f1_single_iter)
    finally:
        # The last prefetch round is never consumed; stop its workers instead
        # of leaving the pool running when the next sampler starts.
        pool.terminate()
        pool.join()

    # record F1 score and training time
    if args.record_f1:
        txt_name = filename + '.txt'
        result_pkl[sample_method] = record_result(args, txt_name, total_time_all, samp_num_list, valid_f1_all,
                                                  test_f1_all, epoch_num, epoch_time_all, write_file,
                                                  sample_method, original_stdout)
        print(sample_method, "\'s information recorded")

# Record all collected results (arguments plus one entry per sampler) as a
# pickle under ./result/<dataset>/.
if args.record_f1:
    result_dir = './result/{}'.format(args.dataset)
    # Create the directory up front so a missing folder cannot discard the
    # results of a finished run.
    os.makedirs(result_dir, exist_ok=True)
    with open('{}/{}.pkl'.format(result_dir, filename), 'wb') as f:
        pkl.dump(result_pkl, f)
    print("All information is recorded")
    
